{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71fe81dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "import IPython.display as display\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import absl.logging\n",
    "\n",
    "absl.logging.set_verbosity(absl.logging.ERROR)\n",
    "\n",
    "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\"\n",
    "# 0 = all messages are logged (default behavior)\n",
    "# 1 = INFO messages are not printed\n",
    "# 2 = INFO and WARNING messages are not printed\n",
    "# 3 = INFO, WARNING, and ERROR messages are not printed\n",
    "\n",
    "# On Mac you may encounter an error related to OMP, this is a workaround, but slows down the code\n",
    "# https://github.com/dmlc/xgboost/issues/1715\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"True\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "132953ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
    "tf.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43ca9a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "if tf.test.gpu_device_name():\n",
    "    print(\"Default GPU Device:{}\".format(tf.test.gpu_device_name()))\n",
    "else:\n",
    "    print(\"Please install GPU version of TF if you have one.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ac37fbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openbot import dataloader, data_augmentation, utils, train"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05c28641",
   "metadata": {},
   "source": [
    "## Set train and test dirs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc941e8e",
   "metadata": {},
   "source": [
    "Define the dataset directory and give it a name. Inside the dataset folder, there should be two folders, `train_data` and `test_data`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "823ef8f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_dir = \"dataset\"\n",
    "dataset_name = \"openbot\"\n",
    "train_data_dir = os.path.join(dataset_dir, \"train_data\")\n",
    "test_data_dir = os.path.join(dataset_dir, \"test_data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4388bbaa",
   "metadata": {},
   "source": [
    "## Hyperparameters\n",
    "<a id='hyperparameters'></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "625eb2bd",
   "metadata": {},
   "source": [
    "You may have to tune the learning rate and batch size depending on your available compute resources and dataset. As a general rule of thumb, if you increase the batch size by a factor of n, you can increase the learning rate by a factor of sqrt(n). In order to accelerate training and make it smoother, you should increase the batch size as much as possible. In our paper we used a batch size of 128. For debugging and hyperparameter tuning, you can set the number of epochs to a small value like 10. If you want to train a model which will achieve good performance, you should set it to 50 or more. In our paper we used 100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a14c7cac",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = train.Hyperparameters()\n",
    "\n",
    "params.MODEL = \"pilot_net\"  # choices: \"pilot_net\",\"cil_mobile\",\"cil_mobile_fast\",\"cil\"\n",
    "params.POLICY = \"autopilot\"  # choices: \"autopilot\",\"point_goal_nav\"\n",
    "params.TRAIN_BATCH_SIZE = 128\n",
    "params.TEST_BATCH_SIZE = 16\n",
    "params.LEARNING_RATE = 0.0003\n",
    "params.NUM_EPOCHS = 100\n",
    "params.BATCH_NORM = True  # use batch norm (recommended)\n",
    "params.FLIP_AUG = False  # flip image and controls as augmentation (only autopilot)\n",
    "params.CMD_AUG = False  # randomize high-level command as augmentation (only autopilot)\n",
    "params.USE_LAST = False  # resume training from last checkpoint\n",
    "params.WANDB = False\n",
    "# policy = \"autopilot\": images are expected to be 256x96 - no cropping required\n",
    "# policy = \"point_goal_nav\": images are expected to be 160x120 - cropping to 160x90\n",
    "params.IS_CROP = params.POLICY == \"point_goal_nav\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82ac0929",
   "metadata": {},
   "source": [
    "## Pre-process the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b4494b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr = train.Training(params)\n",
    "tr.dataset_name = dataset_name\n",
    "tr.train_data_dir = train_data_dir\n",
    "tr.test_data_dir = test_data_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abfc9b9b",
   "metadata": {},
   "source": [
    "Running this for the first time will take some time. This code will match image frames to the controls (labels) and indicator signals (commands).  By default, data samples where the vehicle was stationary will be removed. If this is not desired, you need to set `tr.remove_zeros = False`. If you have made any changes to the sensor files, changed `remove_zeros` or moved your dataset to a new directory, you need to set `tr.redo_matching = True`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9d78141",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr.redo_matching = False\n",
    "tr.remove_zeros = True\n",
    "train.process_data(tr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c531aecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import threading\n",
    "\n",
    "\n",
    "def broadcast(event, payload=None):\n",
    "    print(event, payload)\n",
    "\n",
    "\n",
    "event = threading.Event()\n",
    "my_callback = train.MyCallback(broadcast, event)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fc670c9",
   "metadata": {},
   "source": [
    "In the next step, you can convert your dataset to a tfrecord, a data format optimized for training. You can skip this step if you already created a tfrecord before or if you want to train using the files directly. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59336936",
   "metadata": {},
   "outputs": [],
   "source": [
    "train.create_tfrecord(my_callback, policy=tr.hyperparameters.POLICY)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0a4a1cf",
   "metadata": {},
   "source": [
    "## Load the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a51ee383",
   "metadata": {},
   "source": [
    "If you did not create a tfrecord and want to load and buffer files from disk directly, set `no_tf_record = True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc2d2068",
   "metadata": {},
   "outputs": [],
   "source": [
    "no_tf_record = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "082c90bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "if no_tf_record:\n",
    "    tr.train_data_dir = train_data_dir\n",
    "    tr.test_data_dir = test_data_dir\n",
    "    train.load_data(tr, verbose=0)\n",
    "else:\n",
    "    tr.train_data_dir = os.path.join(dataset_dir, \"tfrecords/train.tfrec\")\n",
    "    tr.test_data_dir = os.path.join(dataset_dir, \"tfrecords/test.tfrec\")\n",
    "    train.load_tfrecord(tr, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edf34f3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select interactive backend to show inline\n",
    "%matplotlib inline\n",
    "utils.show_batch(dataset=tr.train_ds, policy=tr.hyperparameters.POLICY, model=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a5d0f77",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71034420",
   "metadata": {},
   "source": [
    "The number of epochs is proportional to the training time. One epoch means going through the complete dataset once. Increasing `NUM_EPOCHS` will mean longer training time, but generally leads to better performance. To get familiar with the code it can be set to small values like `5` or `10`. To train a model that performs well, it should be set to values between `50` and `200`. Setting `USE_LAST` to `True` will resume the training from the last checkpoint. The default values are `NUM_EPOCHS = 100` and `USE_LAST = False`. They are set in [Hyperparameters](#hyperparameters)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "894cd54f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the directory if it doesn't exist\n",
    "os.makedirs(os.path.join(os.getcwd(), \"models\", tr.model_name), exist_ok=True)\n\n",
    "# params.NUM_EPOCHS = 200\n",
    "# params.USE_LAST = True\n",
    "train.do_training(tr, my_callback, verbose=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cfd4aac",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f623b65f",
   "metadata": {},
   "source": [
    "The loss and mean absolute error should decrease. This indicates that the model is fitting the data well. The custom metrics (direction and angle) should go towards 1. These provide some additional insight to the training progress. The direction metric measures whether or not predictions are in the same direction as the labels. Similarly, the angle metric measures if the prediction is within a small angle of the labels. The intuition is that driving in the right direction with the correct steering angle is the most critical part for good final performance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4867aaa7",
   "metadata": {},
   "source": [
    "### Plot metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ca0b291",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.arange(tr.INITIAL_EPOCH + 1, tr.history.params[\"epochs\"] + 1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efd65e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure().gca().xaxis.get_major_locator().set_params(integer=True)\n",
    "plt.plot(x, tr.history.history[\"loss\"], label=\"loss\")\n",
    "plt.plot(x, tr.history.history[\"val_loss\"], label=\"val_loss\")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.legend(loc=\"upper right\")\n",
    "plt.savefig(os.path.join(tr.log_path, \"loss.png\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97e98f63",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure().gca().xaxis.get_major_locator().set_params(integer=True)\n",
    "plt.plot(x, tr.history.history[\"mean_absolute_error\"], label=\"mean_absolute_error\")\n",
    "plt.plot(\n",
    "    x, tr.history.history[\"val_mean_absolute_error\"], label=\"val_mean_absolute_error\"\n",
    ")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Mean Absolute Error\")\n",
    "plt.legend(loc=\"upper right\")\n",
    "plt.savefig(os.path.join(tr.log_path, \"error.png\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa1752d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure().gca().xaxis.get_major_locator().set_params(integer=True)\n",
    "plt.plot(x, tr.history.history[\"direction_metric\"], label=\"direction_metric\")\n",
    "plt.plot(x, tr.history.history[\"val_direction_metric\"], label=\"val_direction_metric\")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Direction Metric\")\n",
    "plt.legend(loc=\"lower right\")\n",
    "plt.savefig(os.path.join(tr.log_path, \"direction.png\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de04eec",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure().gca().xaxis.get_major_locator().set_params(integer=True)\n",
    "plt.plot(x, tr.history.history[\"angle_metric\"], label=\"angle_metric\")\n",
    "plt.plot(x, tr.history.history[\"val_angle_metric\"], label=\"val_angle_metric\")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Angle Metric\")\n",
    "plt.legend(loc=\"lower right\")\n",
    "plt.savefig(os.path.join(tr.log_path, \"angle.png\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e72c1de3",
   "metadata": {},
   "source": [
    "### Save tf lite models for best train, best val and last checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92a1ef62",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_train_checkpoint = \"cp-best-train.ckpt\"\n",
    "best_train_tflite = utils.generate_tflite(tr.checkpoint_path, best_train_checkpoint)\n",
    "utils.save_tflite(best_train_tflite, tr.checkpoint_path, \"best-train\")\n",
    "best_train_index = np.argmin(np.array(tr.history.history[\"loss\"]))\n",
    "print(\n",
    "    \"Best Train Checkpoint (epoch %s) - angle: %.4f, val_angle: %.4f, direction: %.4f, val_direction: %.4f\"\n",
    "    % (\n",
    "        best_train_index,\n",
    "        tr.history.history[\"angle_metric\"][best_train_index],\n",
    "        tr.history.history[\"val_angle_metric\"][best_train_index],\n",
    "        tr.history.history[\"direction_metric\"][best_train_index],\n",
    "        tr.history.history[\"val_direction_metric\"][best_train_index],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb605b90",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_val_checkpoint = \"cp-best-val.ckpt\"\n",
    "best_val_tflite = utils.generate_tflite(tr.checkpoint_path, best_val_checkpoint)\n",
    "utils.save_tflite(best_val_tflite, tr.checkpoint_path, \"best\")\n",
    "utils.save_tflite(best_val_tflite, tr.checkpoint_path, \"best-val\")\n",
    "best_val_index = np.argmin(np.array(tr.history.history[\"val_loss\"]))\n",
    "print(\n",
    "    \"Best Val Checkpoint (epoch %s) - angle: %.4f, val_angle: %.4f, direction: %.4f, val_direction: %.4f\"\n",
    "    % (\n",
    "        best_val_index,\n",
    "        tr.history.history[\"angle_metric\"][best_val_index],\n",
    "        tr.history.history[\"val_angle_metric\"][best_val_index],\n",
    "        tr.history.history[\"direction_metric\"][best_val_index],\n",
    "        tr.history.history[\"val_direction_metric\"][best_val_index],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e2b1644",
   "metadata": {},
   "outputs": [],
   "source": [
    "last_checkpoint = \"cp-last.ckpt\"\n",
    "last_tflite = utils.generate_tflite(tr.checkpoint_path, last_checkpoint)\n",
    "utils.save_tflite(last_tflite, tr.checkpoint_path, \"last\")\n",
    "print(\n",
    "    \"Last Checkpoint - angle: %.4f, val_angle: %.4f, direction: %.4f, val_direction: %.4f\"\n",
    "    % (\n",
    "        tr.history.history[\"angle_metric\"][-1],\n",
    "        tr.history.history[\"val_angle_metric\"][-1],\n",
    "        tr.history.history[\"direction_metric\"][-1],\n",
    "        tr.history.history[\"val_direction_metric\"][-1],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0c57018",
   "metadata": {},
   "source": [
    "### Evaluate the best model (train loss) on the training set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a03c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_train_model = utils.load_model(\n",
    "    os.path.join(tr.checkpoint_path, best_train_checkpoint),\n",
    "    tr.loss_fn,\n",
    "    tr.metric_list,\n",
    "    tr.custom_objects,\n",
    ")\n",
    "loss, mae, direction, angle = best_train_model.evaluate(\n",
    "    tr.train_ds,\n",
    "    steps=tr.image_count_train / tr.hyperparameters.TRAIN_BATCH_SIZE,\n",
    "    verbose=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92462709",
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.show_batch(\n",
    "    dataset=tr.train_ds, policy=tr.hyperparameters.POLICY, model=best_train_model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10a078b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.compare_tf_tflite(best_train_model, best_train_tflite)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2ad2605",
   "metadata": {},
   "source": [
    "### Evaluate the best model (val loss) on the validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2ac5b20",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_val_model = utils.load_model(\n",
    "    os.path.join(tr.checkpoint_path, best_val_checkpoint),\n",
    "    tr.loss_fn,\n",
    "    tr.metric_list,\n",
    "    tr.custom_objects,\n",
    ")\n",
    "loss, mae, direction, angle = best_val_model.evaluate(\n",
    "    tr.test_ds,\n",
    "    steps=tr.image_count_test / tr.hyperparameters.TEST_BATCH_SIZE,\n",
    "    verbose=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e01e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.show_batch(\n",
    "    dataset=tr.test_ds, policy=tr.hyperparameters.POLICY, model=best_val_model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daf9dbac",
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.compare_tf_tflite(best_val_model, best_val_tflite)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29c9d606",
   "metadata": {},
   "source": [
    "## Save the notebook as HTML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c47915f",
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.save_notebook()\n",
    "current_file = \"policy_learning.ipynb\"\n",
    "output_file = os.path.join(tr.log_path, \"notebook.html\")\n",
    "utils.output_HTML(current_file, output_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa902ebd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
